/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "DHCPProtocol.h"

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

#include <iomanip>

namespace aiengine {

// Constructs the DHCP dissector, registering it with the Protocol base class
// under the name "DHCP" and the IPPROTO_UDP transport.
DHCPProtocol::DHCPProtocol():
	Protocol("DHCP", IPPROTO_UDP) {}

// Destructor: explicitly resets anomaly_ (the object processFlow uses to count
// DHCP_BOGUS_HEADER anomalies) before the members are torn down.
DHCPProtocol::~DHCPProtocol() {

        anomaly_.reset();
}

// Decides whether a packet belongs to this dissector.
// A packet is accepted when it holds at least header_size bytes and either the
// source or the destination port is 67; in that case its payload is mapped as
// the DHCP header. total_valid_packets_ / total_invalid_packets_ track results.
bool DHCPProtocol::check(const Packet &packet) {

	const int length = packet.getLength();

	// Short-circuit: the ports are only inspected for long-enough packets.
	const bool accepted = (length >= header_size) and
		((packet.getSourcePort() == 67) or (packet.getDestinationPort() == 67));

	if (!accepted) {
		++total_invalid_packets_;
		return false;
	}

	setHeader(packet.getPayload());
	++total_valid_packets_;
	return true;
}

// Applies the same dynamic-allocation policy to the info, host-name and IP
// caches (in that order).
void DHCPProtocol::setDynamicAllocatedMemory(bool dynamic) {

	const auto apply = [dynamic] (auto &cache) {
		cache->setDynamicAllocatedMemory(dynamic);
	};

	apply(info_cache_);
	apply(host_cache_);
	apply(ip_cache_);
}

// Reports the allocation policy. info_cache_ is used as the representative
// because setDynamicAllocatedMemory sets every cache to the same value.
bool DHCPProtocol::isDynamicAllocatedMemory() const {

	const bool dynamic = info_cache_->isDynamicAllocatedMemory();

	return dynamic;
}

// Bytes currently in use: the dissector object itself plus the in-use
// portion of the info, host-name and IP caches.
uint64_t DHCPProtocol::getCurrentUseMemory() const {

	const uint64_t self_bytes = sizeof(DHCPProtocol);

	return self_bytes
		+ info_cache_->getCurrentUseMemory()
		+ host_cache_->getCurrentUseMemory()
		+ ip_cache_->getCurrentUseMemory();
}

// Bytes allocated: the dissector object itself plus everything the info,
// host-name and IP caches have allocated (used or not).
uint64_t DHCPProtocol::getAllocatedMemory() const {

	const uint64_t self_bytes = sizeof(DHCPProtocol);

	return self_bytes
		+ info_cache_->getAllocatedMemory()
		+ host_cache_->getAllocatedMemory()
		+ ip_cache_->getAllocatedMemory();
}

// Total footprint: cache/object allocations plus the key bytes held by the
// host-name and IP maps.
uint64_t DHCPProtocol::getTotalAllocatedMemory() const {

	const uint64_t cache_bytes = getAllocatedMemory();
	const uint64_t map_bytes = compute_memory_used_by_maps();

	return cache_bytes + map_bytes;
}

// Sums the sizes of the keys stored in host_map_ and ip_map_.
uint64_t DHCPProtocol::compute_memory_used_by_maps() const {

	uint64_t key_bytes = 0;

	for (const auto &entry: host_map_)
		key_bytes += entry.first.size();

	for (const auto &entry: ip_map_)
		key_bytes += entry.first.size();

	return key_bytes;
}

// Number of failed acquisitions across the info, host-name and IP caches
// (accumulated with 32-bit unsigned wrap-around).
uint32_t DHCPProtocol::getTotalCacheMisses() const {

	const uint32_t info_misses = info_cache_->getTotalFails();
	const uint32_t host_misses = host_cache_->getTotalFails();
	const uint32_t ip_misses = ip_cache_->getTotalFails();

	return info_misses + host_misses + ip_misses;
}

// Returns every DHCP-related cached object to its pool and logs a summary.
// - Detaches the DHCPInfo from every flow in the flow table and releases it.
// - Releases the host-name and IP strings still referenced by the maps, then
//   clears both maps.
// - Updates data_time_ and reports the memory reclaimed.
// Does nothing if the flow manager (flow_mng_) is no longer alive.
void DHCPProtocol::releaseCache() {

	if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
		auto ft = fm->getFlowTable();

		std::ostringstream msg;
		msg << "Releasing " << name() << " cache";

		infoMessage(msg.str());

		uint64_t total_cache_bytes_released = compute_memory_used_by_maps();
		uint64_t total_bytes_released_by_flows = 0;
		uint64_t total_cache_save_bytes = 0;
		uint32_t release_flows = 0;
		uint32_t release_host = host_map_.size();
		uint32_t release_ips = ip_map_.size();

		for (auto &flow: ft) {
			if (SharedPointer<DHCPInfo> info = flow->getDHCPInfo(); info) {
				// Count the DHCPInfo object itself; sizeof(info) would only
				// measure the smart pointer wrapping it.
				total_bytes_released_by_flows += sizeof(DHCPInfo);

				flow->layer7info.reset();
				++release_flows;
				info_cache_->release(info);
			}
		}

		// Some entries can still be on the maps and need to be returned
		// to their caches. hits - 1 approximates how many times an existing
		// string was reused instead of allocated (assumes entries start at
		// hits == 1 on insertion — TODO confirm).
		for (auto &entry: host_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(host_cache_, entry.second.sc);
		}
		host_map_.clear();

		for (auto &entry: ip_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(ip_cache_, entry.second.sc);
		}
		ip_map_.clear();

		data_time_ = boost::posix_time::microsec_clock::local_time();

		msg.str("");
		msg << "Release " << release_host << " host names, " << release_ips << " ips, "  << release_flows << " flows";
		computeMemoryUtilization(msg, total_cache_bytes_released, total_bytes_released_by_flows, total_cache_save_bytes);
		infoMessage(msg.str());
	}
}

// Gives the flow's DHCPInfo (if it has one) back to info_cache_.
void DHCPProtocol::releaseFlowInfo(Flow *flow) {

	auto dhcp_info = flow->getDHCPInfo();
	if (!dhcp_info)
		return;

	info_cache_->release(dhcp_info);
}

// Associates a host name with info, sharing one StringCache per distinct
// name through host_map_. Only the first name is kept: if info already has
// a host name nothing changes. If the name is new and host_cache_ cannot
// supply a StringCache, the name is silently not recorded.
void DHCPProtocol::attach_host_name(DHCPInfo *info, const boost::string_ref &name) {

	if (info->host_name)
		return;

	auto found = host_map_.find(name);
	if (found != host_map_.end()) {
		auto &cached = found->second;
		++cached.hits;
		info->host_name = cached.sc;
		return;
	}

	auto fresh = host_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(name.data(), name.length());
	info->host_name = fresh;
	host_map_.insert(std::make_pair(fresh->name(), fresh));
}

// Associates a dotted-quad IP string with info, sharing one StringCache per
// distinct address through ip_map_. Only the first address is kept: if
// info already has one nothing changes. If the address is new and
// ip_cache_ cannot supply a StringCache, it is silently not recorded.
void DHCPProtocol::attach_ip(DHCPInfo *info, const boost::string_ref &ip) {

	if (info->ip)
		return;

	auto found = ip_map_.find(ip);
	if (found != ip_map_.end()) {
		auto &cached = found->second;
		++cached.hits;
		info->ip = cached.sc;
		return;
	}

	auto fresh = ip_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(ip.data(), ip.length());
	info->ip = fresh;
	ip_map_.insert(std::make_pair(fresh->name(), fresh));
}

// Scans the DHCP options of a BOOTREQUEST for option 12 (Host Name) and
// attaches the first one found to info.
//
// payload points at the first option byte; length is the number of option
// bytes available. Options are parsed as RFC 2132 TLVs:
// - Pad (0) is a single byte.
// - End (255) stops the scan.
// - Every other option must fit completely inside length before its value is
//   touched, so truncated or hostile option lists cannot cause out-of-bounds
//   reads.
void DHCPProtocol::handle_request(DHCPInfo *info, const uint8_t *payload, int length) {

	int idx = 0;
	while (idx < length) {
		const uint8_t type = payload[idx];

		if (type == 0) { // Pad: no length byte follows
			++idx;
			continue;
		}
		if (type == 255) // End of the option list
			break;

		// The length byte and the complete value must be inside the buffer.
		if (idx + 1 >= length)
			break;
		const int len = payload[idx + 1];
		if (idx + 2 + len > length)
			break;

		if (type == 12) { // Hostname
			boost::string_ref name(reinterpret_cast<const char*>(&payload[idx + 2]), len);

			attach_host_name(info, name);
			break;
		}
		idx += 2 + len;
	}
}

// Scans the DHCP options of a BOOTREPLY for option 51 (IP Address Lease
// Time) and stores its big-endian 32-bit value in info.
//
// payload points at the first option byte; length is the number of option
// bytes available. Options are parsed as RFC 2132 TLVs:
// - Pad (0) is a single byte.
// - End (255) stops the scan.
// - Every other option must fit completely inside length.
// The lease value is only read when the option carries at least 4 bytes.
void DHCPProtocol::handle_reply(DHCPInfo *info, const uint8_t *payload, int length) {

	int idx = 0;
	while (idx < length) {
		const uint8_t type = payload[idx];

		if (type == 0) { // Pad: no length byte follows
			++idx;
			continue;
		}
		if (type == 255) // End of the option list
			break;

		// The length byte and the complete value must be inside the buffer.
		if (idx + 1 >= length)
			break;
		const int len = payload[idx + 1];
		if (idx + 2 + len > length)
			break;

		if (type == 51) { // IP Lease time
			if (len >= 4) {
				const uint8_t *value = &payload[idx + 2];
				// Assemble in unsigned arithmetic to avoid shifting into the
				// sign bit; the bit pattern stored is the same as before.
				const uint32_t raw = (static_cast<uint32_t>(value[0]) << 24) |
					(static_cast<uint32_t>(value[1]) << 16) |
					(static_cast<uint32_t>(value[2]) << 8) |
					static_cast<uint32_t>(value[3]);

				info->setLeaseTime(static_cast<int32_t>(raw));
			}
			break;
		}
		idx += 2 + len;
	}
}

// Converts the header's yiaddr field to dotted-quad text and attaches it to
// info via attach_ip.
// inet_ntop writes into a local buffer, unlike inet_ntoa, which returns a
// shared static buffer (not reentrant, and the string_ref would alias it).
void DHCPProtocol::handle_ip_address(DHCPInfo *info) {

	in_addr addr;
	addr.s_addr = header_->yiaddr;

	char ipstr[INET_ADDRSTRLEN];
	if (inet_ntop(AF_INET, &addr, ipstr, sizeof(ipstr)) == nullptr)
		return; // cannot happen for AF_INET with an INET_ADDRSTRLEN buffer

	boost::string_ref ipref(ipstr);

	attach_ip(info, ipref);
}

// Dissects one DHCP packet belonging to flow.
// When the packet is longer than header_size and carries the 63 82 53 63
// magic cookie:
// - Ensures the flow has a DHCPInfo; if info_cache_ is exhausted, it logs
//   the failure and returns.
// - Counts the option-53 message type, and records the offered IP on
//   DHCPOFFER.
// - Extracts the host name (requests) or the lease time (replies).
// Otherwise the packet is flagged as DHCP_BOGUS_HEADER.
void DHCPProtocol::processFlow(Flow *flow) {

	CPUCycle cycles(&total_cpu_cycles_);
	setHeader(flow->packet->getPayload());
	int length = flow->packet->getLength();
	total_bytes_ += length;

	current_flow_ = flow;

	++total_packets_;

	// if there is no magic, then there is no request
	if ((length > header_size)and(header_->magic[0] == 0x63)and(header_->magic[1] == 0x82)and
		(header_->magic[2] == 0x53)and(header_->magic[3] == 0x63)) {

		// Read the BOOTP op field only once the packet is known to hold a
		// complete header (it used to be read before the length check).
		uint8_t msgtype = getType();

		SharedPointer<DHCPInfo> info = flow->getDHCPInfo();
		if (!info) {
			if (info = info_cache_->acquire(); !info) {
				logFailCache(info_cache_->name(), flow);
				return;
			}
			flow->layer7info = info;
		}

		int options_length = length - header_size;
		const uint8_t *optpayload = &header_->opt[0];

		// Option 53 is code, length, value: all three bytes must be present
		// before optpayload[2] can be read.
		if ((options_length >= 3)and(optpayload[0] == 53)) { // Extract the dhcp message type
			short type = optpayload[2];

			if (type == DHCPDISCOVER) {
				++total_dhcp_discover_;
			} else if (type == DHCPOFFER) {
				++total_dhcp_offer_;
				// Extract the IP
				handle_ip_address(info.get());
			} else if (type == DHCPREQUEST) {
				++total_dhcp_request_;
			} else if (type == DHCPDECLINE) {
				++total_dhcp_decline_;
			} else if (type == DHCPACK) {
				++total_dhcp_ack_;
			} else if (type == DHCPNAK) {
				++total_dhcp_nak_;
			} else if (type == DHCPRELEASE) {
				++total_dhcp_release_;
			} else if (type == DHCPINFORM) {
				++total_dhcp_inform_;
			}
		}

		if (msgtype == DHCP_BOOT_REQUEST)
			handle_request(info.get(), optpayload, options_length);
		else
			handle_reply(info.get(), optpayload, options_length);

	} else {
		// Malformed DHCP packet
		if (flow->getPacketAnomaly() == PacketAnomalyType::NONE)
			flow->setPacketAnomaly(PacketAnomalyType::DHCP_BOGUS_HEADER);

		anomaly_->incAnomaly(PacketAnomalyType::DHCP_BOGUS_HEADER);
	}
}

// Writes human-readable statistics to out.
// - Always: the common header.
// - level > 3: per-message-type counters and the cache statistics.
// - level > 4: the host-name and IP maps (at most `limit` entries, as
//   interpreted by the map's show()).
// - level > 5: the flow forwarder's statistics, if it is still alive
//   (printed between the counters and the cache statistics).
void DHCPProtocol::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	showStatisticsHeader(out, level);

	const bool detailed = level > 3;

	if (detailed) {
		out << "\t" << "Total discovers:        " << std::setw(10) << total_dhcp_discover_ << "\n"
			<< "\t" << "Total offers:           " << std::setw(10) << total_dhcp_offer_ << "\n"
			<< "\t" << "Total requests:         " << std::setw(10) << total_dhcp_request_ << "\n"
			<< "\t" << "Total declines:         " << std::setw(10) << total_dhcp_decline_ << "\n"
			<< "\t" << "Total acks:             " << std::setw(10) << total_dhcp_ack_ << "\n"
			<< "\t" << "Total naks:             " << std::setw(10) << total_dhcp_nak_ << "\n"
			<< "\t" << "Total releases:         " << std::setw(10) << total_dhcp_release_ << "\n"
			<< "\t" << "Total informs:          " << std::setw(10) << total_dhcp_inform_ << std::endl;
	}

	if (level > 5) {
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}

	if (!detailed)
		return;

	info_cache_->statistics(out);
	host_cache_->statistics(out);
	ip_cache_->statistics(out);

	if (level > 4) {
		host_map_.show(out, "\t", limit);
		ip_map_.show(out, "\t", limit);
	}
}

// JSON flavour of statistics: writes the common header and, for level > 3,
// an object under "types" with one counter per DHCP message type.
void DHCPProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	Json types;

	types["discovers"] = total_dhcp_discover_;
	types["offers"] = total_dhcp_offer_;
	types["requests"] = total_dhcp_request_;
	types["declines"] = total_dhcp_decline_;
	types["acks"] = total_dhcp_ack_;
	types["nacks"] = total_dhcp_nak_;
	types["releases"] = total_dhcp_release_;
	types["informs"] = total_dhcp_inform_;

	out["types"] = types;
}

// Grows the info, host-name and IP caches by `value` items each (in that order).
void DHCPProtocol::increaseAllocatedMemory(int value) {

	const auto grow = [value] (auto &cache) { cache->create(value); };

	grow(info_cache_);
	grow(host_cache_);
	grow(ip_cache_);
}

// Shrinks the info, host-name and IP caches by `value` items each (in that order).
void DHCPProtocol::decreaseAllocatedMemory(int value) {

	const auto shrink = [value] (auto &cache) { cache->destroy(value); };

	shrink(info_cache_);
	shrink(host_cache_);
	shrink(ip_cache_);
}

// Returns a snapshot of the protocol counters keyed by name: the generic
// packet/byte totals followed by one entry per DHCP message type.
CounterMap DHCPProtocol::getCounters() const {

	CounterMap counters;

	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);
	counters.addKeyValue("discovers", total_dhcp_discover_);
	counters.addKeyValue("offers", total_dhcp_offer_);
	counters.addKeyValue("requests", total_dhcp_request_);
	counters.addKeyValue("declines", total_dhcp_decline_);
	counters.addKeyValue("acks", total_dhcp_ack_);
	counters.addKeyValue("naks", total_dhcp_nak_);
	counters.addKeyValue("releases", total_dhcp_release_);
	counters.addKeyValue("informs", total_dhcp_inform_);

	return counters;
}

// Scripting-binding accessors; compiled only for the Python or Ruby bindings.
#if defined(PYTHON_BINDING) || defined(RUBY_BINDING)
// Exposes one of the maps as a binding hash/dict:
// - "name" or "host" (case-insensitive) selects host_map_.
// - "ip" selects ip_map_.
// - Any other name falls through to a StringMap built from two empty strings
//   (presumably producing an empty result — TODO confirm StringMap's ctor).
// NOTE(review): the Ruby variant is named getCache rather than getCacheData;
// verify this matches the header declaration for RUBY_BINDING builds.
#if defined(PYTHON_BINDING)
boost::python::dict DHCPProtocol::getCacheData(const std::string &name) const {
#elif defined(RUBY_BINDING)
VALUE DHCPProtocol::getCache(const std::string &name) const {
#endif
        if (boost::iequals(name, "name")or(boost::iequals(name, "host")))
		return addMapToHash(host_map_);
        else if (boost::iequals(name, "ip"))
		return addMapToHash(ip_map_);

	StringMap empty {"", ""};

        return addMapToHash(empty);
}

#if defined(PYTHON_BINDING)
// Python only: returns the string cache selected by name ("name"/"host" ->
// host_cache_, "ip" -> ip_cache_, case-insensitive), or nullptr otherwise.
SharedPointer<Cache<StringCache>> DHCPProtocol::getCache(const std::string &name) {

        if (boost::iequals(name, "name")or(boost::iequals(name, "host")))
                return host_cache_;
        else if (boost::iequals(name, "ip"))
                return ip_cache_;

        return nullptr;
}

#endif

#endif

// Dumps one map into out as name -> hits pairs:
// - "names" (case-insensitive) selects host_map_.
// - "ips" selects ip_map_.
// - Any other value leaves out untouched.
// NOTE(review): `limit` is currently ignored — every entry is emitted, unlike
// the ostream statistics() which forwards it to show().
void DHCPProtocol::statistics(Json &out, const std::string &map_name, int32_t limit) const {

	const auto dump = [&out] (const auto &map) {
		for (const auto &item: map)
			out.emplace(item.first, item.second.hits);
	};

	if (boost::iequals(map_name, "names"))
		dump(host_map_);
	else if (boost::iequals(map_name, "ips"))
		dump(ip_map_);
}

// Resets the base-class counters via reset(), then zeroes every
// per-message-type DHCP counter.
void DHCPProtocol::resetCounters() {

	reset();

	total_dhcp_discover_ = total_dhcp_offer_ = total_dhcp_request_ =
		total_dhcp_decline_ = total_dhcp_ack_ = total_dhcp_nak_ =
		total_dhcp_release_ = total_dhcp_inform_ = 0;
}

} // namespace aiengine
